#install.packages("mvtnorm")
library(mvtnorm)
#' Fit a K-component Gaussian mixture model with the EM algorithm
#'
#' Runs a fixed number of EM iterations starting from equal mixing
#' proportions, identity covariances, and means drawn at random around the
#' column means of `X`. A scatter plot of the final clustering is always
#' drawn on the active graphics device.
#'
#' @param X Numeric matrix (or object coercible with `as.matrix()`), one
#'   observation per row.
#' @param K Number of mixture components (at least 1).
#' @param n.iter Number of EM iterations to run (at least 1).
#' @param plot If `TRUE`, redraw the current clustering at every iteration.
#' @param ... Currently unused.
#' @return A list with elements `P` (n x K posterior membership
#'   probabilities from the last E step), `prop` (mixing proportions), `mu`
#'   (K x p matrix of means), `Sigma` (K x p x p array of covariances),
#'   `loglik` (observed-data log-likelihood at the final parameters) and
#'   `bic` (`loglik` minus half the number of free parameters times
#'   `log(n)`, so larger is better).
EMalgo <- function(X, K, n.iter = 50, plot = FALSE, ...) {
  X <- as.matrix(X)
  stopifnot(K >= 1, n.iter >= 1)
  n <- nrow(X)
  p <- ncol(X)
  # One colour per component, matching the `max.col(P) + 1` point colours.
  center_cols <- seq_len(K) + 1

  # n x K matrix whose (i, k) entry is prop[k] * N(x_i; mu_k, Sigma_k).
  weighted_dens <- function(prop, mu, Sigma) {
    dens <- matrix(NA_real_, n, K)
    for (k in seq_len(K)) {
      # Rebuild the p x p matrix explicitly: Sigma[k, , ] drops to a scalar
      # when p == 1.
      dens[, k] <- prop[k] * dmvnorm(X, mu[k, ], matrix(Sigma[k, , ], p, p))
    }
    dens
  }

  # Initialize
  prop <- rep(1 / K, K)
  mu <- rmvnorm(K, mean = colMeans(X))
  Sigma <- array(0, dim = c(K, p, p))
  for (k in seq_len(K)) Sigma[k, , ] <- diag(p)

  # The EM loop
  for (it in seq_len(n.iter)) {
    cat(".")

    # E step: posterior membership probabilities (rows sum to 1)
    P <- weighted_dens(prop, mu, Sigma)
    P <- P / rowSums(P)

    # Optional plot of the current state
    if (plot) {
      plot(X, col = max.col(P) + 1)
      points(mu, col = center_cols, pch = 19, cex = 3)
      Sys.sleep(0.1)
    }

    # M step: weighted proportion, mean and covariance of each component
    for (k in seq_len(K)) {
      w <- P[, k]
      nk <- sum(w)
      prop[k] <- nk / n
      mu[k, ] <- colSums(w * X) / nk
      Xc <- sweep(X, 2, mu[k, ])
      Sigma[k, , ] <- crossprod(Xc, w * Xc) / nk
    }
  }
  cat("\n")
  plot(X, col = max.col(P) + 1)
  points(mu, col = center_cols, pch = 19, cex = 3)

  # Observed-data log-likelihood: sum_i log( sum_k prop_k * N(x_i; mu_k, Sigma_k) )
  loglik <- sum(log(rowSums(weighted_dens(prop, mu, Sigma))))

  # Free parameters: K - 1 proportions, K * p mean coordinates and
  # K * p * (p + 1) / 2 distinct covariance entries.
  n_par <- (K - 1) + K * p + K * p * (p + 1) / 2
  bic <- loglik - n_par / 2 * log(n)

  list(P = P, prop = prop, mu = mu, Sigma = Sigma,
       bic = bic, loglik = loglik)
}
Let’s start with an easy situation:
# Two well-separated clusters: 100 points around (0, 0) and 200 around (-2, 2).
X <- rbind(
  rmvnorm(100, mean = rep(0, 2), sigma = diag(0.1, 2)),
  rmvnorm(200, mean = c(-2, 2), sigma = diag(0.2, 2))
)
plot(X)
out <- EMalgo(X, K = 2)
## ..................................................
out$mu
## [,1] [,2]
## [1,] -1.993723983 2.03243307
## [2,] -0.001575338 -0.00781512
# Overlay the fitted component means on the raw data.
plot(X)
points(out$mu, col = 2:3, pch = 19, cex = 3)
Here, we can see the effect of each iteration on the inference of the means and the estimation of cluster memberships:
# Ten iterations only, redrawing the clustering after every E step.
out <- EMalgo(X, K = 2, n.iter = 10, plot = TRUE)
## .
## .
## .
## .
## .
## .
## .
## .
## .
## .
Let’s now consider a more difficult problem:
# Same two centres as before, but a larger variance so the clusters overlap.
X <- rbind(
  rmvnorm(100, mean = rep(0, 2), sigma = diag(0.8, 2)),
  rmvnorm(200, mean = c(-2, 2), sigma = diag(0.8, 2))
)
plot(X)
out <- EMalgo(X, K = 2)
## ..................................................
out$mu
## [,1] [,2]
## [1,] -1.8202362 1.89251405
## [2,] 0.2863645 0.01007436
Let’s now consider the problem of choosing \(K\):
# Two moderately overlapping clusters, fitted with K = 1, ..., 6 components.
X <- rbind(
  rmvnorm(100, mean = c(0, 0), sigma = 0.4 * diag(2)),
  rmvnorm(200, mean = c(-2, 2), sigma = 0.4 * diag(2))
)
plot(X)
# Show the six fits in a 2 x 3 grid, then restore the previous layout so that
# later plots are not squeezed into a grid cell.
old_par <- par(mfrow = c(2, 3))
for (k in 1:6) EMalgo(X, k, n.iter = 50)
par(old_par)
## ..................................................
## ..................................................
## ..................................................
## ..................................................
## ..................................................
## ..................................................
Let’s now try to use BIC to select the best \(K\):
# BIC of the fitted mixture for each candidate number of components.
bic <- rep(NA_real_, 6)
for (k in seq_along(bic)) {
  bic[k] <- EMalgo(X, k)$bic
}
## ..................................................
## ..................................................
## ..................................................
## ..................................................
## ..................................................
## ..................................................
plot(bic, type = "b")